# Algorithm 0:
from itertools import combinations, chain
import collections
import numpy as np
from Environment.Environments.Pusher1D.pusher1D import Pusher1D
from Environment.Environments.ACDomains.Domains.forest_fire import ForestFire
from Environment.Environments.ACDomains.Domains.rock_throwing import RockThrowing
from Environment.Environments.ACDomains.Domains.gang_shoot import GangShoot
from Environment.Environments.ACDomains.Domains.halt_charge import HaltCharge
from Environment.Environments.ACDomains.Domains.train import Train
from Environment.Environments.ACDomains.Domains.voting import Voting
from Environment.Environments.ACDomains.Domains.mod_DAG import ModDAG
import sys, os, copy
import time
from math import comb


def get_all_subsets(n):
    """Enumerate every subset of ``range(n)`` as a tuple of indices.

    Subsets are grouped by size, starting with the empty tuple and ending
    with the full set; within one size they follow the order produced by
    ``itertools.combinations``.
    """
    by_size = (combinations(range(n), size) for size in range(n + 1))
    return list(chain.from_iterable(by_size))

def hash_vector(vals):
    """Fold a vector into a single number by reading its entries as base-10 digits.

    Entry ``i`` is weighted by ``10**i`` (the first entry is the least
    significant digit).  Vectors of equal length are only guaranteed to get
    distinct keys when every entry is a single decimal digit (0-9).
    An empty vector maps to ``0``.
    """
    return sum(v * np.power(10, i) for i, v in enumerate(vals))

def compute_possible(environment):
    """Brute-force search over binary masks and subsets of states.

    Reads ``num_objects``, ``all_states`` and ``outcomes`` from
    ``environment``.  For every 0/1 mask over the objects and every subset of
    state indices, the subset is accepted for that mask when no two states in
    it share the same masked value (as keyed by ``hash_vector``) while having
    different outcomes.  Combinations that pick one accepted subset per mask
    are then filtered by the nested ``check_valid`` and given a cost.

    The function prints accepted (mask, subset) pairs and has no ``return``
    statement, so it returns ``None``; ``min_cost`` is computed but not used.
    """
    # a binary includes the object, or does not
    all_binaries = np.array(np.meshgrid(*[[0,1] for i in range(environment.num_objects)])).T.reshape(-1,environment.num_objects)
    all_states = environment.all_states
    outcomes = environment.outcomes
    
    
    # get all valid binary assignments, a binary is invalid if the same input maps to different outcomes
    subsets = get_all_subsets(len(all_states))
    # all_state_combinations = np.array(np.meshgrid(*[[0,1] for i in range(environment.num_objects)])).T.reshape(-1,environment.num_objects)
    valid_subsets = list()
    # try out every binary
    for binary in all_binaries:
        bin_valid = list()
        # try out every subset of states
        for subset in subsets:
            # maps the key of a masked state to the outcome first seen for it
            outcome_map = dict()
            valid = True
            inval = None
            for i in subset:
                # multiplying by the mask zeroes the entries the binary excludes
                inval = hash_vector(all_states[i] * binary)
                if inval in outcome_map:
                    if outcome_map[inval] != outcomes[i]:
                        # same masked input, different outcome: reject this subset
                        valid=False
                        break
                else:
                    outcome_map[inval] = outcomes[i]
            if valid: print(binary, valid, outcome_map, subset, inval, )
            if valid:
                bin_valid.append(subset)
        valid_subsets.append(bin_valid)
    # print(valid_subsets, len(valid_subsets), [len(vss) for vss in valid_subsets])
    # subset_indices is only referenced by the commented-out alternative below
    subset_indices = [np.arange(len(vss)) for vss in valid_subsets]
    # print(subset_indices)
    # NOTE(review): np.meshgrid is handed lists of tuples of differing lengths
    # here -- confirm this builds the intended cross product of subsets.
    all_combinations = np.array(np.meshgrid(*valid_subsets)).T.reshape(-1,len(valid_subsets))
    # all_combinations = np.array(np.meshgrid(*subset_indices)).T.reshape(-1,len(valid_subsets))
    # print(all_combinations)

    # `comb` (parameter here, loop variable below) shadows the `comb` imported
    # from math at the top of the file; `num` is not used in the body.
    def check_valid(comb, num): # checks if a combination of binary assignments is valid
        # print(comb)
        for i in range(len(comb)-1):
            for j in range(i+1,len(comb)):
                # print(i, j, comb[i], comb[j], all(np.isin(np.array(comb[i]),np.array(comb[j]))))
                if any(np.isin(comb[i],comb[j])): # overlapping subsets are invalid
                    return False
        all_covered_states = list(set(np.array(comb).flatten()))
        # concatenate the non-empty subsets into one flat list of state indices
        all_covered_states = sum([list(ac) for ac in all_covered_states if len(ac) > 0], start=list())
        all_covered_states.sort()
        # if comb[0] == (0,2,3): print(comb, all_covered_states)
        # print(all_covered_states)
        # valid only when the number of covered indices equals the number of states
        return len(all_covered_states) == len(all_states)


    valid_combinations = list()
    for i, comb in enumerate(all_combinations):
        start = time.time()  # timestamp is taken but never read
        if check_valid(comb, len(all_states)):
            valid_combinations.append(comb)
            # print(i, len(valid_combinations), len(all_combinations))
    cost = list()
    for valid_combination in valid_combinations:
        # cost = sum over masks of (number of ones in the mask) * (size of its subset)
        cost.append(np.sum(np.array([np.sum(bin) for bin in all_binaries]) * np.array([len(c) for c in valid_combination])))
        # print(valid_combination, cost[-1])
    # min() raises ValueError when no combination was valid (empty `cost`)
    min_cost = min(cost)
    # print("min cost combinations")
    # for valid_combination, c in zip(valid_combinations, cost):
    #     if c == min_cost:
    #         print(valid_combination, c)

def partition(list_, k):
    """
    Generate every way of splitting ``list_`` into exactly ``k`` groups.

    partition([1, 2, 3], 2) -> [[1], [2, 3]], [[1, 2], [3]], [[2], [1, 3]]

    Assumes 1 <= k <= len(list_); other values of k are not handled.
    """
    if k == 1:
        # a single group: the input itself, wrapped in a list
        yield [list_]
        return
    if k == len(list_):
        # as many groups as items: every item is on its own
        yield [[item] for item in list_]
        return

    first, *rest = list_
    # option A: `first` is a group by itself, the rest fills the other k-1 groups
    for groups in partition(rest, k - 1):
        yield [[first]] + groups
    # option B: the rest already forms k groups and `first` joins each one in turn
    for groups in partition(rest, k):
        for slot, group in enumerate(groups):
            yield groups[:slot] + [[first] + group] + groups[slot + 1:]

def get_all_disjoint_sets(iterable):
    """Chain together the partitions of ``iterable`` for every group count.

    The group count runs from 1 up to the number of elements; the partitions
    for one group count are materialised as a list when the returned iterator
    reaches them.
    """
    elements = list(iterable)
    group_counts = range(1, len(elements) + 1)
    return chain.from_iterable(
        map(lambda k: list(partition(elements, k)), group_counts)
    )

def yield_all_disjoint_sets(iterable):
    """Lazily yield every partition of ``iterable`` into non-empty groups.

    Partitions are produced in order of group count ``k = 1 .. n`` where
    ``n`` is the number of elements.  The upper bound is inclusive so that
    the all-singletons partition (``k == n``) is produced and a one-element
    input yields its single partition.  An empty input yields nothing.

    Each yielded value is a list of groups; the single-group partition
    contains the element list itself rather than a copy.
    """
    l = list(iterable)
    n = len(l)
    for k in range(1, n + 1):
        if k == 1:
            # one group holding everything
            yield [l]
        elif k == n:
            # every element in its own group
            yield [[s] for s in l]
        else:
            # general case: delegate to the recursive generator
            yield from partition(l, k)

# NOTE(review): a `global` statement at module level has no effect.  The
# `counter` names assigned in the visible part of this file are local to the
# functions that assign them -- presumably a leftover; confirm before removing.
global counter

def binary_state_compatibility(all_binaries, all_states, all_outcomes, environment, use_witness=False):
    """Tabulate the counterfactual split scores of every binary in every state.

    Calls ``environment.evaluate_split_counterfactuals(binary, state, outcome,
    use_witness=...)`` for each (binary, state/outcome) pair and keeps the
    first two values of the returned triple.

    Returns a dict mapping the index of each binary to a list with one
    ``(pos_comp, neg_comp)`` tuple per state, in the order of ``all_states``.
    The third returned value is summed up but not returned.
    """
    total_cf_cost = 0
    table = {}
    for b_idx, binary in enumerate(all_binaries):
        row = []
        for state, outcome in zip(all_states, all_outcomes):
            result = environment.evaluate_split_counterfactuals(
                binary, state, outcome, use_witness=use_witness
            )
            pos_comp, neg_comp, cf_cost = result
            row.append((pos_comp, neg_comp))
            total_cf_cost += cf_cost
        table[b_idx] = row
    return table

def bell_number(n):
    """Return the n-th Bell number (number of partitions of an n-element set).

    Uses the recurrence ``B_0 = 1`` and
    ``B_i = sum_{j=0}^{i-1} C(i-1, j) * B_j``, giving 1, 1, 2, 5, 15, 52, ...
    For ``n <= 0`` the loop does not run and 1 is returned.
    """
    bell_numbers = [1]
    for i in range(1, n + 1):
        # bell_numbers currently holds B_0 .. B_{i-1}
        bell_numbers.append(
            sum(comb(i - 1, j) * b for j, b in enumerate(bell_numbers))
        )
    return bell_numbers[-1]

def check_subset(b, mb):
    """Test whether ``b`` is a strict subset of ``mb``.

    Both arguments are treated as equal-length 0/1 vectors.  The result is
    true when ``mb`` has at least one active position that ``b`` lacks and
    ``b`` has no active position that ``mb`` lacks.
    """
    delta = mb + (-b)  # +1 where only mb is set, -1 where only b is set
    only_in_mb = np.sum(delta[delta > 0])
    only_in_b = np.sum(delta[delta < 0])
    return only_in_mb > 0 and only_in_b == 0

def check_disjoint(b, mb):
    """Return True when the 0/1 vectors ``b`` and ``mb`` share no active position.

    The two vectors must have the same length.  Their element-wise product is
    non-zero exactly at the positions set in both, so the vectors are disjoint
    when that product is zero everywhere.  Two all-zero vectors count as
    disjoint.
    """
    return not np.any(np.multiply(b, mb))

def check_smaller(b, minimal_subsets):
    """Drop every entry of ``minimal_subsets`` that strictly contains ``b``.

    An entry is removed when ``check_subset(b, entry)`` is true.  The input is
    left untouched; the survivors are returned as a NumPy array built from a
    deep copy of ``minimal_subsets``.
    """
    retained = [
        position
        for position, candidate in enumerate(minimal_subsets)
        if not check_subset(b, candidate)
    ]
    snapshot = np.array(copy.deepcopy(minimal_subsets))
    return snapshot[retained]

def state_minimality(s, binaries, compatibility, one_constant, zero_constant):
    """Find the binaries admissible in state ``s`` and the minimal ones among them.

    Args:
        s: index of the state in each per-binary compatibility list.
        binaries: sequence of 0/1 vectors; position ``i`` corresponds to
            ``compatibility[i]``.
        compatibility: mapping from binary index to a list of
            ``(pos_comp, neg_comp)`` pairs, one per state.
        one_constant: lower bound on ``pos_comp``.  A negative value turns
            the compatibility filter off, so every binary is admissible.
        zero_constant: upper bound on ``neg_comp``.

    Returns:
        A pair ``(valid_binaries, strict_minimal_subsets)``.  The first is the
        list of admissible binaries; the second is what remains after removing
        every admissible binary that strictly contains another one.  The pair
        is returned for every value of ``one_constant`` so callers can index
        the result uniformly.
    """
    if one_constant < 0:
        # compatibility is not consulted when the constant is negative
        valid_binaries = list(binaries)
    else:
        valid_binaries = [
            b
            for i, b in enumerate(binaries)  # binaries identified by index
            if not (
                compatibility[i][s][0] < one_constant
                or compatibility[i][s][1] > zero_constant
            )
        ]

    strict_minimal_subsets = copy.deepcopy(valid_binaries)
    for b in valid_binaries:
        # discard every candidate that has b as a strict subset
        strict_minimal_subsets = check_smaller(b, strict_minimal_subsets)
    return valid_binaries, strict_minimal_subsets

def create_minimum_binary_table_string(subset_binary_dict, all_states):
    """Serialise the per-state binaries into one ``;``-separated string.

    Each state contributes a row ``<binaries>-<state>``: the binaries are
    joined with ``|`` (their entries with single spaces) and the state values
    with ``,``.  Binaries are looked up under ``tuple(state)``.
    """
    rows = []
    for state in all_states:
        binary_field = "|".join(
            " ".join(str(bit) for bit in binary)
            for binary in subset_binary_dict[tuple(state)]
        )
        state_field = ",".join(str(value) for value in state)
        rows.append(binary_field + "-" + state_field)
    return ";".join(rows)

def compute_possible_efficient(environment, one_constant, zero_constant, use_witness= False, save_path="", use_invariant=True, use_zero=False):
    """Search for the minimum-cost assignments of binaries to partitions of the states.

    Steps, as implemented below:
      1. Enumerate every 0/1 mask ("binary") over ``environment.num_objects``
         objects; the first row (the all-zero mask) is dropped unless
         ``use_zero`` is true.
      2. Score every (binary, state) pair with ``binary_state_compatibility``
         and build a per-state table string of minimal binaries.
      3. For every subset of state indices, record which binaries are allowed
         for it (invariance check when ``use_invariant`` is true, then the
         threshold check against ``one_constant`` / ``zero_constant``).
      4. For every partition produced by ``yield_all_disjoint_sets``, list
         all ways of giving each group its own allowed binary (no binary is
         reused within one partition).
      5. Cost each assignment and collect the ones with minimum cost.

    Args:
        environment: object providing ``num_objects``, ``all_states``,
            ``outcomes``, ``passive_mask`` and
            ``evaluate_split_counterfactuals``.
        one_constant: a binary is kept for a state only when the first
            element of its compatibility pair is not below this value.
            A negative value disables the threshold check in
            ``check_compatible``.
        zero_constant: a binary is kept for a state only when the second
            element of its compatibility pair does not exceed this value.
        use_witness: forwarded to ``binary_state_compatibility``.
        save_path: when non-empty, results are written to this file and to
            three siblings named ``save_path[:-4]`` + ``_assign.txt`` /
            ``_hist.txt`` / ``_bin_table.txt``.
            NOTE(review): the slice assumes a four-character extension such
            as ".txt" -- confirm against callers.
        use_invariant: when false, every binary is treated as passing the
            invariance check.
        use_zero: keep the all-zero binary.

    Returns:
        ``(min_cost_strings, minimal_binary_table_string)``.  The first item
        is a list with one string per minimum-cost assignment, or the empty
        string ``""`` when no assignment was found.
    """
    # a binary includes the object, or does not
    all_binaries = np.array(np.meshgrid(*[[0,1] for i in range(environment.num_objects)])).T.reshape(-1,environment.num_objects)
    use_zero = use_zero
    if not use_zero: all_binaries = all_binaries[1:]
    all_states = environment.all_states
    outcomes = environment.outcomes
    passive_mask = environment.passive_mask
    all_subsets = get_all_subsets(len(all_states))
    
    # print(np.concatenate([np.array(all_states), np.expand_dims(np.array(outcomes), axis=-1)], axis=-1))
    # describes which binaries are allowed in a particular state
    compatibility = binary_state_compatibility(all_binaries, all_states, outcomes, environment, use_witness=use_witness)
    # for k in compatibility.keys():
    #     for s, c in enumerate(compatibility[k]):
    #         print(all_binaries[int(k)], all_states[s], c)
    # per-state minimal binaries, keyed by the state as a tuple of Python values
    minimal_assignments = dict()
    for sidx in range(len(all_states)):
        print("state", all_states[sidx])
        # element [1] of the state_minimality result is stored for this state
        minimal_assignments[tuple(all_states[sidx].tolist())] = state_minimality(sidx, all_binaries, compatibility, one_constant, zero_constant)[1]
    minimal_binary_table_string = create_minimum_binary_table_string(minimal_assignments, all_states)
    print(minimal_binary_table_string, minimal_assignments)

    # NOTE: the `binaries` parameter is not used; the loop runs over the
    # enclosing `all_binaries`.
    def check_valid(subset, binaries):
        # check which binaries are compatible with the given subset
        valid_binaries = list()
        for i, binary in enumerate(all_binaries):
            # check if binary is compatible with the following subset
            binary_check = dict()
            invalid = False
            for s in subset:
                # state of the form [k_0, ..., k_n] where n is the number of factors, k_i is the discrete value of factor i
                factored_state = all_states[s]
                # convert to a tuple where the binary sets certain values to 0
                masked_factored_state = tuple((factored_state * binary).tolist())
                if masked_factored_state in binary_check:
                    if binary_check[masked_factored_state] != outcomes[s]:
                        # invalid if the same masked state has at least two different outcomes
                        invalid = True
                        break
                else: # assign the masked state
                    binary_check[masked_factored_state] = outcomes[s]
            if not invalid: # append the index of the binary
                valid_binaries.append(i)
        return valid_binaries
    
    # keeps the binary indices whose compatibility pair passes both thresholds
    # in every state of the subset
    def check_compatible(subset, binaries, compatibility, one_constant, zero_constant):
        if one_constant < 0: return binaries # don't use compatibility if constant negative
        valid_binaries = list()
        for i in binaries:# binaries identified by index
            compatible = compatibility[i]
            subset_compatible = True
            for s in subset:
                if compatible[s][0] < one_constant or compatible[s][1] > zero_constant:
                    subset_compatible = False
                    break
            if subset_compatible: valid_binaries.append(i)
        return valid_binaries
    # create a mapping of every subset to each of its valid binaries
    subset_binary = dict()          # subset index -> allowed binary indices
    subset_index_mapping = dict()   # subset (tuple of state indices) -> subset index
    print("num subsets", len(list(all_subsets)))
    print("num binaries", len(all_binaries))
    for i, subset in enumerate(all_subsets):
        valid_binaries = check_valid(subset, all_binaries) if use_invariant else np.arange(len(all_binaries))
        subset_binary[i] = check_compatible(subset, valid_binaries, compatibility, one_constant, zero_constant)
        subset_index_mapping[tuple(subset)] = i
    # create all disjoint, complete partitionings of the subsets
    disjoint_sets = yield_all_disjoint_sets(range(len(all_states)))
    # print("total djs", bell_number(len(all_states)))
    # print("num disjoint", len(disjoint_sets))
    # print("num valid binaries", len(list(subset_binary.keys())))

    # Recursively assigns a binary to the first group, marks it unusable and
    # recurses on the remaining groups.  Returns three values: the binary
    # index lists, the matching subset index lists, and a work counter.
    def all_assignments(disjoint_subset, unusable_binaries): # returns all mappings of unusable binaries to a particular disjoint subset
        if len(disjoint_subset) == 1:
            subset_valid = subset_binary[subset_index_mapping[tuple(disjoint_subset[0])]]
            usable_binaries = set(subset_valid) - set(unusable_binaries)
            return [[ub] for ub in usable_binaries], [[subset_index_mapping[tuple(disjoint_subset[0])]] for _ in usable_binaries], len(usable_binaries)
        else:
            counter = 0
            subset_valid = subset_binary[subset_index_mapping[tuple(disjoint_subset[0])]]
            usable_binaries = set(subset_valid) - set(unusable_binaries)
            # print("ss", subset_valid, usable_binaries, unusable_binaries)
            binary_assn = list()
            subset_assn = list()
            for b in usable_binaries: # append each usable bin to the front
                rem_bin, rem_subset, ctr = all_assignments(disjoint_subset[1:], unusable_binaries + [b])
                counter += ctr
                for rb, rs in zip(rem_bin, rem_subset):
                    counter += 1
                    binary_assn.append([b] + rb)
                    subset_assn.append([subset_index_mapping[tuple(disjoint_subset[0])]] + rs)
                    # print([b] + assign)
            return binary_assn, subset_assn, counter

    # for each disjoint set, find all valid assignments of binary to set
    # (every assignment of every partition is accumulated in these two lists)
    assigned_binaries = list()
    assigned_subsets = list()
    counter = 0
    for idx, ds in enumerate(disjoint_sets):
        ab, asub, count  = all_assignments(ds, [])
        counter += count
        assigned_binaries += ab
        assigned_subsets += asub
        if idx % 1000000 == 0: print(idx, count)
    # print(idx)
    print("cost counter, number of assigned binaries", counter, len(assigned_binaries))
    
    cost = list()
    for ab, asub in zip(assigned_binaries, assigned_subsets):
        # print(ab, convert_subset(ab, all_binaries), convert_subset(asub, all_subsets), np.sum(np.sum(convert_subset(ab, all_binaries), axis=-1) * np.array([len(all_subsets[c]) for c in asub])))
        # cost = sum over groups of (entries where the binary differs from
        # passive_mask) * (number of states in the group)
        # assumes passive_mask broadcasts against the stacked binaries -- TODO confirm
        cost.append(np.sum(np.array(np.sum(np.abs(convert_subset(ab, all_binaries) - passive_mask), axis=-1)) * np.array([len(all_subsets[c]) for c in asub])))
        # print(valid_combination, cost[-1])
    if len(cost) == 0:
        print("no valid combinations found")
        # note: the first element is "" here but a list on the normal path
        return "", minimal_binary_table_string
    min_cost = min(cost)
    print("min cost combinations")
    cost_counter = collections.Counter()   # histogram: cost -> number of assignments
    min_cost_strings = list()
    min_cost_assignments = dict()          # (state..., outcome) -> set of binaries used at min cost
    for ab, asub, c in zip(assigned_binaries, assigned_subsets, cost):
        if c == min_cost:
            # print(len(ab), np.array(convert_subset(ab, all_binaries)), [(np.array(convert_subset(ss, all_states)), np.array(convert_subset(ss, outcomes))) for ss in convert_subset(asub, all_subsets)], c)
            ss_outcomes = [np.array(convert_subset(ss, outcomes)) for ss in convert_subset(asub, all_subsets)]
            states = [np.array(convert_subset(ss, all_states)) for ss in convert_subset(asub, all_subsets)]
            # each row is the state values followed by the outcome
            state_outcomes = [[s.tolist() + [o] for s, o in zip(state, outcome)]  for state, outcome in zip(states, ss_outcomes)]
            subset_outcome_strings = [[",".join([str(s) for s in so]) for so in soc] for soc in state_outcomes]
            binaries = np.array(convert_subset(ab, all_binaries))
            # row format: "<binary digits>,<state values...>,<outcome>"
            bin_so_strings = sum([["".join([str(b) for b in bn]) + "," + ss for ss in ssos] for bn, ssos in zip(binaries, subset_outcome_strings)], start=list())
            min_cost_strings.append(";".join(bin_so_strings))
            for bn, sso in zip(binaries, state_outcomes): # for each table
                for ss in sso:
                    if tuple(ss) not in min_cost_assignments:
                        min_cost_assignments[tuple(ss)] = set([tuple(bn)])
                    else:
                        min_cost_assignments[tuple(ss)].add(tuple(bn))
        cost_counter[c] += 1
    costs = [i for i in cost_counter.items()]
    costs.sort(key=lambda x: x[0])
    print("num per cost", costs)
    print("min cost assignments", min_cost_assignments)
    # print(min_cost_strings)
    if len(save_path) > 0:
        # one minimum-cost table per line
        with open(save_path, 'w') as f:
            for strv in min_cost_strings:
                f.write(strv + "\n")
        # sibling files share save_path minus its last four characters
        with open(save_path[:-4] + "_assign.txt", 'w') as f:
            for key, value in min_cost_assignments.items():
                f.write(str(key) + "," + str(value) + "\n")
        with open(save_path[:-4] + "_hist.txt", 'w') as f:
            for key, value in costs:
                f.write(str(key) + "," + str(value) + "\n")
        with open(save_path[:-4] + "_bin_table.txt", 'w') as f:
            f.write(minimal_binary_table_string)
    return min_cost_strings, minimal_binary_table_string



def compute_normality_binaries(environment):
    """Prepare the inputs for a per-state binary search (unfinished).

    Builds every 0/1 mask over ``environment.num_objects`` objects (dropping
    the first row unless ``environment.use_zero`` is true), reads the states,
    outcomes and passive mask from the environment, and computes the
    compatibility table.  Nothing is returned yet.
    """
    object_count = environment.num_objects
    axes = [[0,1] for _ in range(object_count)]
    all_binaries = np.array(np.meshgrid(*axes)).T.reshape(-1, object_count)
    if not environment.use_zero:
        all_binaries = all_binaries[1:]

    all_states = environment.all_states
    outcomes = environment.outcomes
    passive_mask = environment.passive_mask  # read but not used yet
    all_subsets = get_all_subsets(len(all_states))  # computed but not used yet

    # one (pos, neg) splitting pair per binary-state pair, indexed by binary
    compatibility = binary_state_compatibility(all_binaries, all_states, outcomes, environment)

    # TODO: find a minimum cost valid binary for each state
    


def convert_subset(subset, all_subsets, sort = False):
    """Look up every index of ``subset`` in ``all_subsets``.

    When ``sort`` is true, ``subset`` is first sorted in place (the caller's
    list is modified).  Returns a new list with the looked-up items in the
    order of ``subset``.
    """
    if sort:
        subset.sort()
    picked = []
    for index in subset:
        picked.append(all_subsets[index])
    return picked

def compute_min_max(binary_string, minimal_binary_table_string):
    """Compare the binaries used in the given tables against the per-state table.

    Args:
        binary_string: iterable of table strings.  Each table is split on
            ``;`` into rows and each row on ``,``; the first field is read as
            a string of binary digits and the remaining fields form the key.
        minimal_binary_table_string: ``;``-separated rows, also split on
            ``,``; the first field holds ``|``-separated binaries whose
            entries are separated by spaces, the remaining fields form the key.
            NOTE(review): create_minimum_binary_table_string in this file
            separates the binaries from the state with "-" rather than ",";
            confirm which row format this parser is meant to receive.

    Returns:
        dict with keys ``"strict subsets"`` (list of ``(state, bin, sbin)``
        triples), ``"min max cost"`` (largest per-key minimum binary size, 0
        when there are no keys) and ``"min mapping"`` (list of
        ``(key, minimum size)`` pairs).
    """
    # First, read the strings into dictionaries
    binary_mapping = {",".join(row.split(",")[1:]): list() for row in minimal_binary_table_string.split(";")} # the binaries assigned to a given state for all tables
    # initial minimum for each key is the number of fields after the first one
    min_mapping = {",".join(row.split(",")[1:]): len(row.split(",")[1:]) for row in minimal_binary_table_string.split(";")} # minimum size of the binary for a given state over all tables
    for table in binary_string:
        for row in table.split(";"):
            values = row.split(",")
            # size of a binary = number of ones among its digit characters
            binary_size = sum([int(i) for i in values[0]])
            state = ",".join(values[1:])
            if state in min_mapping:
                min_mapping[state] = min(binary_size, min_mapping[state])
                binary_mapping[state].append(values[0])
            else:
                # key not present in the per-state table: start a new entry
                min_mapping[state] = binary_size
                binary_mapping[state] = [values[0]]
    single_binary_mapping = dict() # the binaries assigned to a given state for non-global definitions
    for row in minimal_binary_table_string.split(";"):
        values = row.split(",")
        binaries = [[int(i) for i in bs.split(" ")] for bs in values[0].split("|")] if len(values[0]) > 0 else list()
        state = ",".join(values[1:])
        single_binary_mapping[state] = binaries
    
    # For every state, figure out if there are any strict subsets that are
    # weak actual causes
    strict_subsets = list()
    for state in single_binary_mapping.keys():
        bin_subsets = list()
        # First, filter the binaries to disjoint subsets
        # NOTE(review): the entries of binary_mapping are the str fields
        # appended above, while check_subset negates its first argument --
        # confirm these calls work on strings or convert to arrays first.
        for bin in binary_mapping[state]:
            no_replace = list()   # indices of kept entries that do not strictly contain `bin`
            disjoint = True
            for i, rem_bin in enumerate(bin_subsets):
                strict_subset = check_subset(bin, rem_bin)
                disjoint = disjoint and check_disjoint(bin, rem_bin)
                if not strict_subset: 
                    no_replace.append(i)
            bin_subsets = np.array(bin_subsets)[no_replace].tolist()
            if disjoint or len(no_replace) > 0:
                bin_subsets.append(bin)

        # now see if there are any strict subsets comparing to the single binary mapping
        for bin in bin_subsets:
            for sbin in single_binary_mapping[state]:
                # print(bin,sbin)
                strict_subset = check_subset(bin, sbin)
                if strict_subset:
                    strict_subsets.append((state, bin, sbin))
        # print("state", state, binary_mapping[state], bin_subsets, strict_subsets)

    # largest of the per-key minimum binary sizes (0 when nothing was parsed)
    if len(list(min_mapping.values())) > 0: max_min_cost = max(list(min_mapping.values()))
    else: max_min_cost = 0
    print("strict subsets", strict_subsets)
    return {"strict subsets": strict_subsets, "min max cost": max_min_cost, "min mapping": list(min_mapping.items())} 